{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "db4d7066",
   "metadata": {},
   "source": [
    "# Creating a Custom Data Source\n",
    "\n",
    "**PyBroker** comes with pre-built [DataSources](https://www.pybroker.com/en/latest/reference/pybroker.data.html#pybroker.data.DataSource) for [Yahoo Finance](https://www.pybroker.com/en/latest/reference/pybroker.data.html#pybroker.data.YFinance), [Alpaca](https://www.pybroker.com/en/latest/reference/pybroker.data.html#pybroker.data.Alpaca), and [AKShare](https://github.com/akfamily/akshare), which you can use right away without any additional setup. But if you have a specific need or want to use a different data source, **PyBroker** also allows you to create your own ```DataSource``` class.\n",
    "\n",
    "\n",
    "## Extending DataSource\n",
    "\n",
    "In the example code provided below, a new ```DataSource``` called ```CSVDataSource``` is implemented, which loads data from a CSV file. The ```CSVDataSource``` reads a file named ```prices.csv``` into a [Pandas DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html), and then returns the data from this DataFrame based on the input parameters provided:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7e59aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import pybroker\n",
    "from pybroker.data import DataSource\n",
    "\n",
    "class CSVDataSource(DataSource):\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Register custom columns in the CSV.\n",
    "        pybroker.register_columns('rsi')\n",
    "    \n",
    "    def _fetch_data(self, symbols, start_date, end_date, _timeframe, _adjust):\n",
    "        df = pd.read_csv('data/prices.csv')\n",
    "        df = df[df['symbol'].isin(symbols)]\n",
    "        df['date'] = pd.to_datetime(df['date'])\n",
    "        return df[(df['date'] >= start_date) & (df['date'] <= end_date)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3abf367",
   "metadata": {},
   "source": [
    "To make the custom ```'rsi'``` column from the CSV file available to **PyBroker**, we register it using [pybroker.register_columns](https://www.pybroker.com/en/latest/reference/pybroker.scope.html#pybroker.scope.register_columns). This allows **PyBroker** to use this custom column when it processes the data.\n",
    "\n",
    "It's important to note that when returning the data from your custom DataSource, it must include the following columns: ```symbol```, ```date```, ```open```, ```high```, ```low```, and ```close```, as these columns are expected by **PyBroker**.\n",
    "\n",
    "Now we can query the CSV data from an instance of ```CSVDataSource```:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "08a7e845",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading bar data...\n",
      "Loaded bar data: 0:00:00 \n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>symbol</th>\n",
       "      <th>open</th>\n",
       "      <th>high</th>\n",
       "      <th>low</th>\n",
       "      <th>close</th>\n",
       "      <th>rsi</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2021-06-01</td>\n",
       "      <td>DIS</td>\n",
       "      <td>180.179993</td>\n",
       "      <td>181.009995</td>\n",
       "      <td>178.740005</td>\n",
       "      <td>178.839996</td>\n",
       "      <td>46.321532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2021-06-01</td>\n",
       "      <td>MCD</td>\n",
       "      <td>235.979996</td>\n",
       "      <td>235.990005</td>\n",
       "      <td>232.740005</td>\n",
       "      <td>233.240005</td>\n",
       "      <td>46.522926</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2021-06-01</td>\n",
       "      <td>NKE</td>\n",
       "      <td>137.850006</td>\n",
       "      <td>138.050003</td>\n",
       "      <td>134.210007</td>\n",
       "      <td>134.509995</td>\n",
       "      <td>53.308085</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2021-06-02</td>\n",
       "      <td>DIS</td>\n",
       "      <td>179.039993</td>\n",
       "      <td>179.100006</td>\n",
       "      <td>176.929993</td>\n",
       "      <td>177.000000</td>\n",
       "      <td>42.635256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2021-06-02</td>\n",
       "      <td>MCD</td>\n",
       "      <td>233.970001</td>\n",
       "      <td>234.330002</td>\n",
       "      <td>232.809998</td>\n",
       "      <td>233.779999</td>\n",
       "      <td>48.051484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>382</th>\n",
       "      <td>2021-11-30</td>\n",
       "      <td>MCD</td>\n",
       "      <td>247.380005</td>\n",
       "      <td>247.899994</td>\n",
       "      <td>243.949997</td>\n",
       "      <td>244.600006</td>\n",
       "      <td>40.461178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>383</th>\n",
       "      <td>2021-11-30</td>\n",
       "      <td>NKE</td>\n",
       "      <td>168.789993</td>\n",
       "      <td>171.550003</td>\n",
       "      <td>167.529999</td>\n",
       "      <td>169.240005</td>\n",
       "      <td>51.505558</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>384</th>\n",
       "      <td>2021-12-01</td>\n",
       "      <td>DIS</td>\n",
       "      <td>146.699997</td>\n",
       "      <td>148.369995</td>\n",
       "      <td>142.039993</td>\n",
       "      <td>142.149994</td>\n",
       "      <td>16.677555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>385</th>\n",
       "      <td>2021-12-01</td>\n",
       "      <td>MCD</td>\n",
       "      <td>245.759995</td>\n",
       "      <td>250.899994</td>\n",
       "      <td>244.110001</td>\n",
       "      <td>244.179993</td>\n",
       "      <td>39.853689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>386</th>\n",
       "      <td>2021-12-01</td>\n",
       "      <td>NKE</td>\n",
       "      <td>170.889999</td>\n",
       "      <td>173.369995</td>\n",
       "      <td>166.679993</td>\n",
       "      <td>166.699997</td>\n",
       "      <td>46.704527</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>387 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "          date symbol        open        high         low       close  \\\n",
       "0   2021-06-01    DIS  180.179993  181.009995  178.740005  178.839996   \n",
       "1   2021-06-01    MCD  235.979996  235.990005  232.740005  233.240005   \n",
       "2   2021-06-01    NKE  137.850006  138.050003  134.210007  134.509995   \n",
       "3   2021-06-02    DIS  179.039993  179.100006  176.929993  177.000000   \n",
       "4   2021-06-02    MCD  233.970001  234.330002  232.809998  233.779999   \n",
       "..         ...    ...         ...         ...         ...         ...   \n",
       "382 2021-11-30    MCD  247.380005  247.899994  243.949997  244.600006   \n",
       "383 2021-11-30    NKE  168.789993  171.550003  167.529999  169.240005   \n",
       "384 2021-12-01    DIS  146.699997  148.369995  142.039993  142.149994   \n",
       "385 2021-12-01    MCD  245.759995  250.899994  244.110001  244.179993   \n",
       "386 2021-12-01    NKE  170.889999  173.369995  166.679993  166.699997   \n",
       "\n",
       "           rsi  \n",
       "0    46.321532  \n",
       "1    46.522926  \n",
       "2    53.308085  \n",
       "3    42.635256  \n",
       "4    48.051484  \n",
       "..         ...  \n",
       "382  40.461178  \n",
       "383  51.505558  \n",
       "384  16.677555  \n",
       "385  39.853689  \n",
       "386  46.704527  \n",
       "\n",
       "[387 rows x 7 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "csv_data_source = CSVDataSource()\n",
    "df = csv_data_source.query(['MCD', 'NKE', 'DIS'], '6/1/2021', '12/1/2021')\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a21ca3b",
   "metadata": {},
   "source": [
    "To use ```CSVDataSource``` in a backtest, we create a new [Strategy](https://www.pybroker.com/en/latest/reference/pybroker.strategy.html#pybroker.strategy.Strategy) object and pass the custom ```DataSource```:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1238ecd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2021-06-01 00:00:00 to 2021-12-01 00:00:00\n",
      "\n",
      "Loading bar data...\n",
      "Loaded bar data: 0:00:00 \n",
      "\n",
      "Test split: 2021-06-01 00:00:00 to 2021-12-01 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (129 of 129) |######################| Elapsed Time: 0:00:00 Time:  0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished backtest: 0:00:02\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>type</th>\n",
       "      <th>symbol</th>\n",
       "      <th>date</th>\n",
       "      <th>shares</th>\n",
       "      <th>limit_price</th>\n",
       "      <th>fill_price</th>\n",
       "      <th>fees</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buy</td>\n",
       "      <td>NKE</td>\n",
       "      <td>2021-09-21</td>\n",
       "      <td>100</td>\n",
       "      <td>NaN</td>\n",
       "      <td>154.86</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>sell</td>\n",
       "      <td>NKE</td>\n",
       "      <td>2021-11-04</td>\n",
       "      <td>100</td>\n",
       "      <td>NaN</td>\n",
       "      <td>173.82</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>buy</td>\n",
       "      <td>DIS</td>\n",
       "      <td>2021-11-16</td>\n",
       "      <td>100</td>\n",
       "      <td>NaN</td>\n",
       "      <td>159.40</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    type symbol       date  shares  limit_price  fill_price  fees\n",
       "id                                                               \n",
       "1    buy    NKE 2021-09-21     100          NaN      154.86   0.0\n",
       "2   sell    NKE 2021-11-04     100          NaN      173.82   0.0\n",
       "3    buy    DIS 2021-11-16     100          NaN      159.40   0.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pybroker import Strategy\n",
    "\n",
    "def buy_low_sell_high_rsi(ctx):\n",
    "    pos = ctx.long_pos() \n",
    "    if not pos and ctx.rsi[-1] < 30:\n",
    "        ctx.buy_shares = 100\n",
    "    elif pos and ctx.rsi[-1] > 70:\n",
    "        ctx.sell_shares = pos.shares\n",
    "\n",
    "strategy = Strategy(csv_data_source, '6/1/2021', '12/1/2021')\n",
    "strategy.add_execution(buy_low_sell_high_rsi, ['MCD', 'NKE', 'DIS'])\n",
    "result = strategy.backtest()\n",
    "result.orders"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3c94d73",
   "metadata": {},
   "source": [
    "Note that because we registered the custom ```rsi``` column with **PyBroker**, it can be accessed in the [ExecContext](https://www.pybroker.com/en/latest/reference/pybroker.context.html#pybroker.context.ExecContext) using ```ctx.rsi```."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0eb8e5d",
   "metadata": {},
   "source": [
    "## Using a Pandas DataFrame\n",
    "\n",
    "If you do not need the flexibility of implementing your own [DataSource](https://www.pybroker.com/en/latest/reference/pybroker.data.html#pybroker.data.DataSource), then you can pass a [Pandas DataFrame](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.html) to a ``Strategy`` instead.\n",
    "\n",
    "To demonstrate, the earlier example can be re-implemented as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eecb37d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Backtesting: 2021-06-01 00:00:00 to 2021-12-01 00:00:00\n",
      "\n",
      "Test split: 2021-06-01 00:00:00 to 2021-12-01 00:00:00\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100% (129 of 129) |######################| Elapsed Time: 0:00:00 Time:  0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Finished backtest: 0:00:00\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>type</th>\n",
       "      <th>symbol</th>\n",
       "      <th>date</th>\n",
       "      <th>shares</th>\n",
       "      <th>limit_price</th>\n",
       "      <th>fill_price</th>\n",
       "      <th>fees</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>id</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>buy</td>\n",
       "      <td>NKE</td>\n",
       "      <td>2021-09-21</td>\n",
       "      <td>100</td>\n",
       "      <td>NaN</td>\n",
       "      <td>154.86</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>sell</td>\n",
       "      <td>NKE</td>\n",
       "      <td>2021-11-04</td>\n",
       "      <td>100</td>\n",
       "      <td>NaN</td>\n",
       "      <td>173.82</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>buy</td>\n",
       "      <td>DIS</td>\n",
       "      <td>2021-11-16</td>\n",
       "      <td>100</td>\n",
       "      <td>NaN</td>\n",
       "      <td>159.40</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    type symbol       date  shares  limit_price  fill_price  fees\n",
       "id                                                               \n",
       "1    buy    NKE 2021-09-21     100          NaN      154.86   0.0\n",
       "2   sell    NKE 2021-11-04     100          NaN      173.82   0.0\n",
       "3    buy    DIS 2021-11-16     100          NaN      159.40   0.0"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/prices.csv')\n",
    "df['date'] = pd.to_datetime(df['date'])\n",
    "\n",
    "pybroker.register_columns('rsi')\n",
    "\n",
    "strategy = Strategy(df, '6/1/2021', '12/1/2021')\n",
    "strategy.add_execution(buy_low_sell_high_rsi, ['MCD', 'NKE', 'DIS'])\n",
    "result = strategy.backtest()\n",
    "result.orders"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
